In [1]:
##
from __future__ import annotations
In [2]:
import os
os.chdir('/data/l989o/deployed/a')
In [3]:
import sys
import shutil
import scvi
import scanpy as sc
import torch
from torch.utils.data import DataLoader
In [4]:
from data2 import SumFilteredDataset, file_path
import numpy as np
from tqdm import tqdm
import anndata as ad
import pandas as pd
from pytorch_lightning.loggers import TensorBoardLogger
import os
import matplotlib.pyplot as plt
from utils import memory, reproducible_random_choice
done
In [5]:
# COMPLETE_RUN = True
COMPLETE_RUN = True
N_EPOCHS_KL_WARMUP = 3
N_EPOCHS = 10
In [6]:
m = __name__ == '__main__'
In [7]:
##
if m and False:
    # proxy for the DKFZ network
    # https://stackoverflow.com/questions/34576665/setting-proxy-to-urllib-request-python3
    os.environ["HTTP_PROXY"] = "http://193.174.53.86:80"
    os.environ["HTTPS_PROXY"] = "https://193.174.53.86:80"
In [8]:
##
if m and False:
    # to have a look at an existing dataset
    import scvi.data

    data = scvi.data.pbmc_dataset()
    data
In [9]:
##
if m:
    ds = SumFilteredDataset("train")


    @memory.cache
    def f_qpoxnqwida(ds):
        l0 = []
        l1 = []
        for i, x in enumerate(tqdm(ds, "merging")):
            l0.append(x)
            l1.extend([i] * len(x))
        return l0, l1


    l0, l1 = f_qpoxnqwida(ds)
    raw = np.concatenate(l0, axis=0)
    raw = np.round(raw)
    raw = raw.astype(np.int)
    donor = np.array(l1)
    a = ad.AnnData(raw)
merging:   1%|          | 2/226 [00:00<00:20, 10.67it/s]
________________________________________________________________________________
[Memory] Calling __main__--data-l989o-deployed-a-<ipython-input-bb8ce796987d>.f_qpoxnqwida...
f_qpoxnqwida(<data2.SumFilteredDataset object at 0x7eff6042e1f0>)
merging: 100%|██████████| 226/226 [00:18<00:00, 12.36it/s]
____________________________________________________f_qpoxnqwida - 19.0s, 0.3min
In [10]:
##
if m:
    s = pd.Series(donor, index=a.obs.index)
    a.obs["batch"] = s
In [11]:
##
if m:
    scvi.data.setup_anndata(
        a,
        # this is probably meaningless (if not even penalizing) for unseen data as the batches are different
        # categorical_covariate_keys=["batch"],
    )
    a
INFO     No batch_key inputted, assuming all cells are same batch                            
INFO     No label_key inputted, assuming all cells have same label                           
INFO     Using data from adata.X                                                             
INFO     Computing library size prior per batch                                              
INFO     Successfully registered anndata object containing 446738 cells, 39 vars, 1 batches, 
         1 labels, and 0 proteins. Also registered 0 extra categorical covariates and 0 extra
         continuous covariates.                                                              
INFO     Please do not further modify adata until model is trained.                          
In [12]:
##
if m:
    # TRAIN = True
    TRAIN = False
    if TRAIN:
        # vae = VAE(gene_dataset.nb_genes)
        # trainer = UnsupervisedTrainer(
        #     vae,
        #     gene_dataset,
        #     train_size=0.90,
        #     use_cuda=use_cuda,
        #     frequency=5,
        # )
        # []:
        # trainer.train(n_epochs=n_epochs, lr=lr)
        model = scvi.model.SCVI(a)
In [13]:
##
from data2 import file_path
In [14]:
# the following code, as it is, doesn't work
#     logger = TensorBoardLogger(save_dir=file_path("checkpoints"), name="scvi")
#     BATCH_SIZE = 128
#     indices = np.random.choice(len(a), BATCH_SIZE * 20, replace=False)
#
#     train_loader_batch = DataLoader(
#         a.X[indices, :],
#         batch_size=BATCH_SIZE,
#         num_workers=4,
#         pin_memory=True,
#     )
#     model.train(train_size=1., logger=logger, val_dataloaders=train_loader_batch)
#     model.__dict__
if m:
    if TRAIN:
        model.train(train_size=1.0, n_epochs=N_EPOCHS, n_epochs_kl_warmup=N_EPOCHS_KL_WARMUP)
        f = file_path("scvi_model.scvi")
        if os.path.isdir(f):
            shutil.rmtree(f)
        model.save(f)
    else:
        model = scvi.model.SCVI.load(file_path("scvi_model.scvi"), adata=a)
    print(model.get_elbo())
INFO     Using data from adata.X                                                             
INFO     Computing library size prior per batch                                              
INFO     Registered keys:['X', 'batch_indices', 'local_l_mean', 'local_l_var', 'labels']     
INFO     Successfully registered anndata object containing 446738 cells, 39 vars, 1 batches, 
         1 labels, and 0 proteins. Also registered 0 extra categorical covariates and 0 extra
         continuous covariates.                                                              
-187.95363649203713
In [15]:
##
if m:
    z = model.get_latent_representation()
    a.shape
    z.shape
    b = ad.AnnData(z)
    random_indices = reproducible_random_choice(len(a), 10000)
    aa = a[random_indices]
    bb = b[random_indices]
In [16]:
def scanpy_compute(an: ad.AnnData):
    sc.tl.pca(an)
    print("computing neighbors... ", end="")
    sc.pp.neighbors(an)
    print("done")
    print("computing umap... ", end="")
    sc.tl.umap(an)
    print("done")
    print("computing louvain... ", end="")
    sc.tl.louvain(an)
    print("done")
In [17]:
##
if m and COMPLETE_RUN:
    scanpy_compute(aa)
    sc.pl.pca(aa, title="pca, raw data (sum)")
    sc.pl.umap(aa, color="louvain", title="umap with louvain, scvi latent (sum)")
computing neighbors... done
computing umap... done
computing louvain... done
In [18]:
##
if m and COMPLETE_RUN:
    scanpy_compute(bb)
    sc.pl.pca(bb, title="pca, raw data (sum)")
    sc.pl.umap(bb, color="louvain", title="umap with louvain, scvi latent (sum)")
computing neighbors... done
computing umap... done
computing louvain... done
In [19]:
##
from analyses.ab_images_vs_expression.ab_aa_expression_latent_samples import (
    compare_clusters,
    nearest_neighbors,
    compute_knn,
)
Global seed set to 1234
In [20]:
if m and COMPLETE_RUN:
    compare_clusters(aa, bb, description='"raw data (sum)" vs "scvi latent"')
    compute_knn(aa)
    compute_knn(bb)
    nearest_neighbors(
        nn_from=aa, plot_onto=bb, title='nn from "raw data (sum)" to "scvi latent"'
    )
(19, 18)
done
done
100%|██████████| 5/5 [00:00<00:00, 347.26it/s]
In [21]:
##
if m:
    ds = SumFilteredDataset("validation")


    @memory.cache
    def f_ncqoi3faoj(ds):
        l0 = []
        l1 = []
        for i, x in enumerate(tqdm(ds, "merging")):
            l0.append(x)
            l1.extend([i] * len(x))
        return l0, l1


    l0, l1 = f_ncqoi3faoj(ds)
    raw = np.concatenate(l0, axis=0)
    donor = np.array(l1)
    a_val = ad.AnnData(raw)
merging:   1%|          | 1/113 [00:00<00:12,  8.79it/s]
________________________________________________________________________________
[Memory] Calling __main__--data-l989o-deployed-a-<ipython-input-7f457070833c>.f_ncqoi3faoj...
f_ncqoi3faoj(<data2.SumFilteredDataset object at 0x7efe7afe3d30>)
merging: 100%|██████████| 113/113 [00:08<00:00, 12.59it/s]
_____________________________________________________f_ncqoi3faoj - 9.3s, 0.2min
In [22]:
##
if m:
    # note that here with are embedding without the batch information; if you want to look at batches it does not make
    # sense to use another set except to the training one, since the train/val/test split is done by patient first
    scvi.data.setup_anndata(
        a_val,
    )
    z_val = model.get_latent_representation(a_val)
    b_val = ad.AnnData(z_val)
    random_indices_val = reproducible_random_choice(len(a_val), 10000)
    aa_val = a_val[random_indices_val]
    bb_val = b_val[random_indices_val]
INFO     No batch_key inputted, assuming all cells are same batch                            
INFO     No label_key inputted, assuming all cells have same label                           
INFO     Using data from adata.X                                                             
INFO     Computing library size prior per batch                                              
INFO     Successfully registered anndata object containing 219095 cells, 39 vars, 1 batches, 
         1 labels, and 0 proteins. Also registered 0 extra categorical covariates and 0 extra
         continuous covariates.                                                              
INFO     Please do not further modify adata until model is trained.                          
WARNING  Make sure the registered X field in anndata contains unnormalized count data.       
/data/l989o/miniconda3/envs/spatial_uzh/lib/python3.8/site-packages/scvi/data/_anndata.py:793: UserWarning: adata.X does not contain unnormalized count data. Are you sure this is what you want?
  warnings.warn(
In [23]:
##
if m and COMPLETE_RUN:
    scanpy_compute(aa_val)
    scanpy_compute(bb_val)
computing neighbors... done
computing umap... done
computing louvain... done
computing neighbors... done
computing umap... done
computing louvain... done
In [24]:
##
if m and COMPLETE_RUN:
    sc.pl.pca(aa_val, title="pca, raw data (sum); validation set")
    sc.pl.umap(
        aa_val, color="louvain", title="umap with louvain, scvi latent (sum); valiation set"
    )
    sc.pl.pca(bb_val, title="pca, raw data (sum); valiation set")
    sc.pl.umap(
        bb_val, color="louvain", title="umap with louvain, scvi latent (sum); valiation set"
    )
In [25]:
##
if m and COMPLETE_RUN:
    merged = ad.AnnData.concatenate(bb, bb_val, batch_categories=["train", "validation"])
    scanpy_compute(merged)
    plt.figure()
    ax = plt.gca()
    sc.pl.umap(merged, color="batch", ax=ax, show=False)
    plt.tight_layout()
    plt.show()
computing neighbors... done
computing umap... done
computing louvain... done
In [26]:
##
if m:
    size_factors = model.get_latent_library_size(a_val)
WARNING  Make sure the registered X field in anndata contains unnormalized count data.       
In [27]:
##
from data2 import AreaFilteredDataset
In [28]:
if m and COMPLETE_RUN:
    area_ds = AreaFilteredDataset("validation")

    l = []
    for x in tqdm(area_ds, desc="merging"):
        l.append(x)
    areas = np.concatenate(l, axis=0)
merging: 100%|██████████| 113/113 [00:08<00:00, 12.90it/s]
In [29]:
##
if m and COMPLETE_RUN:
    from scipy.stats import pearsonr

    print(size_factors.shape)
    print(areas.shape)
    r, p = pearsonr(size_factors.ravel(), areas.ravel())
    plt.figure()
    plt.scatter(size_factors, areas, s=0.5)
    plt.xlabel("latent size factors")
    plt.ylabel("cell area")
    plt.title(f"r: {r:0.2f} (p: {p:0.2f})")
    plt.show()
(219095, 1)
(219095, 1)
In [30]:
##
# imputation benchmark
from data2 import PerturbedCellDataset
In [31]:
def get_corrupted_entries(split: str):
    ds = PerturbedCellDataset(split)
    ds.perturb()
    corrupted_entries = ds.corrupted_entries.numpy()
    # just a hash
    h = np.sum(np.concatenate(np.where(corrupted_entries == 1)))
    print(f"corrupted entries hash ({split}):", h)
    return corrupted_entries
In [32]:
if m:
    ce_train = get_corrupted_entries("train")
    ce_val = get_corrupted_entries("validation")
corrupted entries hash (train): 389620020511
corrupted entries hash (validation): 93848327662
In [33]:
##
if m:
    ds = SumFilteredDataset("train")


    @memory.cache
    def f_ncqlliwr2(ds):
        l0 = []
        for i, x in enumerate(tqdm(ds, "merging")):
            l0.append(x)
        return l0


    l0 = f_ncqlliwr2(ds)
    raw = np.concatenate(l0, axis=0)
    raw[ce_train] = 0
    raw = np.round(raw)
    raw = raw.astype(np.int)
    a_perturbed = ad.AnnData(raw)
merging:   1%|          | 2/226 [00:00<00:20, 11.10it/s]
________________________________________________________________________________
[Memory] Calling __main__--data-l989o-deployed-a-<ipython-input-2d8cc6047e99>.f_ncqlliwr2...
f_ncqlliwr2(<data2.SumFilteredDataset object at 0x7eff54750b20>)
merging: 100%|██████████| 226/226 [00:17<00:00, 12.95it/s]
_____________________________________________________f_ncqlliwr2 - 17.6s, 0.3min
In [34]:
##
if m:
    scvi.data.setup_anndata(a_perturbed)
    # TRAIN_PERTURBED = True
    TRAIN_PERTURBED = False
    if TRAIN_PERTURBED:
        # to navigate there with PyCharm and set a breakpoint on a warning (haven't done yet)
        import scvi.core.distributions

        model = scvi.model.SCVI(a_perturbed)
    if TRAIN_PERTURBED:
        model.train(train_size=1.0, n_epochs=N_EPOCHS, n_epochs_kl_warmup=N_EPOCHS_KL_WARMUP)
        f = file_path("scvi_model_perturbed.scvi")
        if os.path.isdir(f):
            shutil.rmtree(f)
        model.save(f)
    else:
        model = scvi.model.SCVI.load(file_path("scvi_model_perturbed.scvi"), adata=a)
    print(model.get_elbo())
INFO     No batch_key inputted, assuming all cells are same batch                            
INFO     No label_key inputted, assuming all cells have same label                           
INFO     Using data from adata.X                                                             
INFO     Computing library size prior per batch                                              
INFO     Successfully registered anndata object containing 446738 cells, 39 vars, 1 batches, 
         1 labels, and 0 proteins. Also registered 0 extra categorical covariates and 0 extra
         continuous covariates.                                                              
INFO     Please do not further modify adata until model is trained.                          
INFO     Using data from adata.X                                                             
INFO     Computing library size prior per batch                                              
INFO     Registered keys:['X', 'batch_indices', 'local_l_mean', 'local_l_var', 'labels']     
INFO     Successfully registered anndata object containing 446738 cells, 39 vars, 1 batches, 
         1 labels, and 0 proteins. Also registered 0 extra categorical covariates and 0 extra
         continuous covariates.                                                              
-195.80324340944367
In [35]:
##
if m:
    x_val_perturbed = a_val.X.copy()
    x_val_perturbed[ce_val] = 0
    a_val_perturbed = ad.AnnData(x_val_perturbed)
In [36]:
##
if m:
    p = model.get_likelihood_parameters(a_val_perturbed)
    from scvi.core.distributions import ZeroInflatedNegativeBinomial

    x_val_perturbed_pred = ZeroInflatedNegativeBinomial(
        mu=torch.tensor(p["mean"]),
        theta=torch.tensor(p["dispersions"]),
        zi_logits=torch.tensor(p["dropout"]),
    ).mean.numpy()
INFO     Input adata not setup with scvi. attempting to transfer anndata setup               
INFO     .obs[_scvi_batch] not found in target, assuming every cell is same category         
INFO     .obs[_scvi_labels] not found in target, assuming every cell is same category        
INFO     Using data from adata.X                                                             
INFO     Computing library size prior per batch                                              
INFO     Registered keys:['X', 'batch_indices', 'local_l_mean', 'local_l_var', 'labels']     
INFO     Successfully registered anndata object containing 219095 cells, 39 vars, 1 batches, 
         1 labels, and 0 proteins. Also registered 0 extra categorical covariates and 0 extra
         continuous covariates.                                                              
WARNING  Make sure the registered X field in anndata contains unnormalized count data.       
/data/l989o/miniconda3/envs/spatial_uzh/lib/python3.8/site-packages/scvi/data/_anndata.py:793: UserWarning: adata.X does not contain unnormalized count data. Are you sure this is what you want?
  warnings.warn(
In [37]:
##
if m:
    # ne: normal entries
    ne_train = np.logical_not(ce_train)
    ne_val = np.logical_not(ce_val)
    x_val = a_val.X.copy()

    uu0 = x_val_perturbed_pred[ce_val]
    uu1 = x_val[ce_val]

    vv0 = x_val_perturbed_pred[ne_val]
    vv1 = x_val[ne_val]
##
if m:
    fig = plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.hist(np.abs(uu0 - uu1))
    m = np.mean(np.abs(uu0 - uu1))
    plt.title(f"scores for imputed entries\nmean: {m:0.2f}")
    plt.yscale("log")

    plt.subplot(1, 2, 2)
    plt.hist(np.abs(vv0 - vv1))
    m = np.mean(np.abs(vv0 - vv1))
    plt.title(f"control: normal entries\nmean: {m:0.2f}")
    plt.yscale("log")

    fig.suptitle("abs(original vs predicted)")
    plt.tight_layout()
    plt.show()
In [38]:
##
if m:
    from analyses.aa_reconstruction_benchmark.aa_ad_reconstruction import Prediction, Space

    s = np.abs(uu0 - uu1)
    t = np.abs(vv0 - vv1)
    Prediction.welch_t_test(s, t)
    # the printed p-value is very close to 0
    # conclusion: the score for imputed data is worse than the one from non-perturbed data; this is expected and the
    # alternative case would have been a model whose scores are both bad because it is not properly trained
##
welch's t test: p_value = 0.0
In [39]:
if m:
    scvi_predictions = Prediction(
        original=x_val,
        corrupted_entries=ce_val,
        predictions_from_perturbed=x_val_perturbed_pred,
        space=Space.raw_sum,
        name='scVI',
        split='validation'
    )

    scvi_predictions.plot_reconstruction()
    scvi_predictions.plot_scores()
channels: 100%|██████████| 39/39 [00:34<00:00,  1.12it/s]
In [40]:
##
if m:
    p = scvi_predictions.transform_to(Space.scaled_mean)
    p.name = 'scVI scaled'
    p.plot_reconstruction()
    p.plot_scores()
applying transformation from raw_sum to raw_mean
applying transformation from raw_mean to asinh_mean
applying transformation from asinh_mean to scaled_mean
applying transformation from raw_sum to raw_mean
applying transformation from raw_mean to asinh_mean
applying transformation from asinh_mean to scaled_mean
channels: 100%|██████████| 39/39 [00:36<00:00,  1.07it/s]